[Kernel] Use pre-allocated output buffer for triton kernel fused_experts#29219

Merged
jeejeelee merged 4 commits into vllm-project:main from xyang16:triton
Nov 26, 2025
Conversation


xyang16 (Contributor) commented Nov 22, 2025

Purpose

This PR uses a pre-allocated output buffer for the triton kernel matmul_ogs:

  • Fix N by overriding the moe_problem_size() function in OAITritonExperts: the superclass moe_problem_size expects N to be the second dimension of w1 (see here), but the triton kernels expect N to be the third dimension of w1. Without the override, N is incorrectly assigned the value of K on the triton path.
  • Allocate intermediate_cache13 (shape [M * topk, N // 2]) as the output buffer of the first matmul_ogs.
  • Allocate output (shape [M, K]) as the output buffer of the second matmul_ogs.
  • Add a batch_dim to the output buffer because matmul_ogs expects a 3D output, see here.
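The shape logic in the bullets above can be sketched as a standalone illustration (hypothetical function names and example shapes; the real change lives in vLLM's OAITritonExperts and the triton_kernels matmul_ogs path):

```python
# Hypothetical sketch, NOT vLLM's actual code: derive the MoE problem size for
# the triton layout and the pre-allocated buffer shapes that follow from it.

def moe_problem_size_triton(w1_shape, num_tokens):
    """Derive (E, M, N, K) assuming the triton weight layout w1 = (E, K, N).

    The generic base class assumes w1 = (E, N, K); reusing it here would
    swap N and K, which is the bug the override fixes.
    """
    E, K, N = w1_shape  # N is the THIRD dimension for the triton kernels
    return E, num_tokens, N, K

def preallocated_buffers(M, N, K, topk):
    # First matmul_ogs output: the gated activation halves N.
    intermediate_cache13 = (M * topk, N // 2)
    # Second matmul_ogs output: back to the hidden size K.
    output = (M, K)
    # matmul_ogs expects a 3D output, so a leading batch_dim of 1 is added.
    output_3d = (1,) + output
    return intermediate_cache13, output, output_3d

# Example with made-up shapes: 32 experts, hidden size 2880, N = 5760.
E, M, N, K = moe_problem_size_triton((32, 2880, 5760), num_tokens=8)
cache13, out, out3d = preallocated_buffers(M, N, K, topk=4)
```

With these example numbers the override yields N = 5760 and K = 2880; the base-class convention would have swapped them.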

Test Plan

pytest -s -v tests/kernels/moe/test_gpt_oss_triton_kernels.py 

Test Result

Unit tests passed.

Accuracy Testing

  • gpt-oss-20b
vllm serve openai/gpt-oss-20b --tensor-parallel-size 1 --max-num-seqs=16 
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105.html
{'chars': np.float64(66.06565656565657), 'chars:std': np.float64(235.89106420986758), 'score': np.float64(0.5681818181818182), 'score:std': np.float64(0.4953294254023493)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_182105_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20251121_182105', 'metric': 0.5681818181818182}]
  • gpt-oss-20b deepep
VLLM_ALL2ALL_BACKEND="deepep_high_throughput" vllm serve openai/gpt-oss-20b --tensor-parallel-size 1 --data-parallel-size 2 --enable-expert-parallel --no-enable-prefix-caching
Writing report to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437.html
{'chars': np.float64(71.08207070707071), 'chars:std': np.float64(268.30453841690064), 'score': np.float64(0.5707070707070707), 'score:std': np.float64(0.4949752621616814)}
Writing results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437.json
Writing all results to /tmp/gpqa_openai__gpt-oss-20b-low_temp1.0_20251121_184437_allresults.json
[{'eval_name': 'gpqa', 'model_name': 'openai__gpt-oss-20b-low_temp1.0_20251121_184437', 'metric': 0.5707070707070707}]

Benchmark

vllm serve openai/gpt-oss-20b \
  --tensor-parallel-size 1 \
  --max-num-seqs 16
vllm bench serve \
  --model openai/gpt-oss-20b \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 60 \
  --ignore-eos

Baseline:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  103.15    
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              9.69      
Output token throughput (tok/s):         1929.46   
Peak output token throughput (tok/s):    2104.00   
Peak concurrent requests:                33.00     
Total Token throughput (tok/s):          4016.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          30.38     
Median TTFT (ms):                        20.28     
P99 TTFT (ms):                           579.15    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          8.03      
Median TPOT (ms):                        8.01      
P99 TPOT (ms):                           9.32      
---------------Inter-token Latency----------------
Mean ITL (ms):                           8.04      
Median ITL (ms):                         7.66      
P99 ITL (ms):                            17.10     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  96.91     
Total input tokens:                      215312    
Total generated tokens:                  199033    
Request throughput (req/s):              10.32     
Output token throughput (tok/s):         2053.88   
Peak output token throughput (tok/s):    2230.00   
Peak concurrent requests:                35.00     
Total Token throughput (tok/s):          4275.74   
---------------Time to First Token----------------
Mean TTFT (ms):                          19.89     
Median TTFT (ms):                        17.34     
P99 TTFT (ms):                           86.70     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.59      
Median TPOT (ms):                        7.55      
P99 TPOT (ms):                           8.64      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.59      
Median ITL (ms):                         7.37      
P99 ITL (ms):                            16.13     
==================================================
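For reference, the relative deltas between the two runs above can be computed from the logged numbers (values copied from the benchmark output):

```python
# Quick arithmetic on the benchmark results above (copied values, no new data).
baseline = {"tok_s": 1929.46, "ttft_ms": 30.38}
pr       = {"tok_s": 2053.88, "ttft_ms": 19.89}

def pct_change(old, new):
    return (new - old) / old * 100.0

throughput_gain = pct_change(baseline["tok_s"], pr["tok_s"])    # roughly +6.4%
ttft_reduction = pct_change(baseline["ttft_ms"], pr["ttft_ms"])  # roughly -34.5%
```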

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

cc @varun-sundar-rabindranath


Labels

gpt-oss (Related to GPT-OSS models), ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done


5 participants